import dgl
import torch as th

from ogb.nodeproppred import DglNodePropPredDataset
import os


def load_citeseer():
    """Load the Citeseer graph from DGL's built-in datasets.

    The node-feature tensor is moved from ``'feat'`` to ``'features'``, and the
    label tensor is additionally exposed under ``'labels'`` (``'label'`` is kept).

    Returns:
        tuple: ``(graph, num_labels)`` where ``num_labels`` is the number of
        distinct label values, ignoring NaN entries.
    """
    from dgl.data import CiteseerGraphDataset

    citeseer_graph = CiteseerGraphDataset()[0]
    node_data = citeseer_graph.ndata

    # Rename the feature field: store under the new key first, then drop the old one.
    node_data['features'] = node_data['feat']
    del node_data['feat']
    node_data['labels'] = node_data['label']

    # Count classes while skipping NaN placeholders (for integer labels
    # th.isnan is all-False, so every entry is kept).
    raw_labels = node_data['label']
    known_labels = raw_labels[~th.isnan(raw_labels)]
    return citeseer_graph, th.unique(known_labels).numel()

def load_flickr():
    """Load the Flickr graph from DGL's built-in datasets.

    The node-feature tensor is moved from ``'feat'`` to ``'features'``, and the
    label tensor is additionally exposed under ``'labels'`` (``'label'`` is kept).

    Returns:
        tuple: ``(graph, num_labels)`` where ``num_labels`` is the number of
        distinct label values, ignoring NaN entries.
    """
    from dgl.data import FlickrDataset

    flickr_graph = FlickrDataset()[0]
    node_data = flickr_graph.ndata

    # Rename the feature field: store under the new key first, then drop the old one.
    node_data['features'] = node_data['feat']
    del node_data['feat']
    node_data['labels'] = node_data['label']

    # Count classes while skipping NaN placeholders (for integer labels
    # th.isnan is all-False, so every entry is kept).
    raw_labels = node_data['label']
    has_label = th.logical_not(th.isnan(raw_labels))
    distinct = th.unique(raw_labels[has_label])
    return flickr_graph, distinct.numel()

def load_reddit():
    """Load the Reddit graph from DGL's built-in datasets.

    The node-feature tensor is moved from ``'feat'`` to ``'features'``, and the
    label tensor is additionally exposed under ``'labels'`` (``'label'`` is kept).

    Returns:
        tuple: ``(graph, num_labels)`` where ``num_labels`` is the number of
        distinct label values, ignoring NaN entries.
    """
    from dgl.data import RedditDataset

    reddit_graph = RedditDataset()[0]
    node_data = reddit_graph.ndata

    # Rename the feature field: store under the new key first, then drop the old one.
    node_data['features'] = node_data['feat']
    del node_data['feat']
    node_data['labels'] = node_data['label']

    # Count classes while skipping NaN placeholders (for integer labels
    # th.isnan is all-False, so every entry is kept).
    raw_labels = node_data['label']
    is_labeled = th.logical_not(th.isnan(raw_labels))
    class_ids = th.unique(raw_labels[is_labeled])
    return reddit_graph, class_ids.numel()

def load_ogb(name):
    """Load an OGB node-property-prediction dataset as a DGL graph.

    Node features are moved from ``'feat'`` to ``'features'``, labels are stored
    under ``'labels'``, and boolean ``'train_mask'`` / ``'val_mask'`` /
    ``'test_mask'`` node fields are built from the dataset's split indices.

    Args:
        name (str): OGB dataset name, e.g. ``'ogbn-arxiv'`` or ``'ogbn-products'``.

    Returns:
        tuple: ``(graph, num_labels)`` where ``num_labels`` is the number of
        distinct label values, ignoring NaN entries.
    """
    print('load', name)
    data = DglNodePropPredDataset(name=name)
    print('finish loading', name)
    splitted_idx = data.get_idx_split()
    graph, labels = data[0]
    # Labels arrive as a 2-D tensor; keep only the first column.
    labels = labels[:, 0]

    graph.ndata['features'] = graph.ndata['feat']
    del graph.ndata['feat']
    graph.ndata['labels'] = labels
    # NaN entries (presumably unlabeled nodes in some OGB datasets) are not counted.
    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))

    # Convert each split's node-ID index into a boolean mask over all nodes.
    num_nodes = graph.number_of_nodes()
    split_to_mask_key = (
        ('train', 'train_mask'),
        ('valid', 'val_mask'),
        ('test', 'test_mask'),
    )
    for split, mask_key in split_to_mask_key:
        mask = th.zeros((num_nodes,), dtype=th.bool)
        mask[splitted_idx[split]] = True
        graph.ndata[mask_key] = mask
    print('finish constructing', name)
    return graph, num_labels

def inductive_split(g):
    """Derive train / validation / test graphs for inductive training.

    The training graph keeps only nodes flagged by ``'train_mask'``. The
    validation graph keeps nodes flagged by either ``'train_mask'`` or
    ``'val_mask'``. The test graph is the input graph itself, unchanged.

    Args:
        g: Graph whose ``ndata`` holds ``'train_mask'`` and ``'val_mask'``.

    Returns:
        tuple: ``(train_g, val_g, test_g)``.
    """
    node_data = g.ndata
    in_train = node_data['train_mask']
    in_train_or_val = in_train | node_data['val_mask']
    return g.subgraph(in_train), g.subgraph(in_train_or_val), g

def load_partitioned_graphs(path_to_partitioned_dataset):
    """
    Load the partitioned graphs stored in a single graph file.

    Args:
        path_to_partitioned_dataset (str): Path to a graph file readable by
            ``dgl.load_graphs``.

    Returns:
        list: The graphs stored in the file. Any auxiliary data returned by
        ``dgl.load_graphs`` beyond the graph list is discarded.

    Raises:
        FileNotFoundError: If ``path_to_partitioned_dataset`` is not an existing file.
    """

    print("Loading partitioned dataset from", path_to_partitioned_dataset)

    # Fail fast with a message naming the missing path rather than letting the
    # loader surface a less specific error.
    if not os.path.isfile(path_to_partitioned_dataset):
        raise FileNotFoundError(
            f"Partitioned dataset file not found: {path_to_partitioned_dataset}"
        )

    # load_graphs returns a tuple; its first element is the list of graphs.
    partitioned_graphs = dgl.load_graphs(path_to_partitioned_dataset)[0]
    print("finish loading partitioned dataset")

    return partitioned_graphs

def load_complete_graphs(dataset):
    """
    Load the complete (unpartitioned) graph for a named dataset.

    Args:
        dataset (str): One of ``'ogbn-products'``, ``'ogbn-arxiv'``,
            ``'reddit'``, ``'citeseer'`` or ``'flickr'``.

    Returns:
        dgl.DGLGraph: The complete graph. The label count reported by the
        underlying loader is discarded.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    print("Loading complete graph for", dataset)

    supported = ('ogbn-products', 'ogbn-arxiv', 'reddit', 'citeseer', 'flickr')

    if dataset in ('ogbn-products', 'ogbn-arxiv'):
        complete_graph, _ = load_ogb(dataset)
    elif dataset == 'reddit':
        complete_graph, _ = load_reddit()
    elif dataset == 'citeseer':
        complete_graph, _ = load_citeseer()
    elif dataset == 'flickr':
        complete_graph, _ = load_flickr()
    else:
        # Name the bad value and the valid choices so the caller can fix the config.
        raise ValueError(
            f"Invalid dataset {dataset!r}; expected one of: {', '.join(supported)}"
        )

    return complete_graph
